{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f4b72fd7",
   "metadata": {},
   "source": [
    "## Description:\n",
    "这个是sharedBottom模型的demo, 尝试在中级API的基础上，加一些loss优化的思路， 这次是GradNorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "324f6ca1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:\n",
      "DeepCTR version 0.9.0 detected. Your version is 0.8.2.\n",
      "Use `pip install -U deepctr` to upgrade.Changelog: https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\ZHONGQ~1\\AppData\\Local\\Temp/ipykernel_63772/2514080491.py:25: The name tf.keras.backend.set_session is deprecated. Please use tf.compat.v1.keras.backend.set_session instead.\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\Users\\ZHONGQ~1\\AppData\\Local\\Temp/ipykernel_63772/2514080491.py:25: The name tf.keras.backend.set_session is deprecated. Please use tf.compat.v1.keras.backend.set_session instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "from sklearn.metrics import roc_auc_score, mean_squared_error, mean_absolute_error\n",
    "from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat\n",
    "from deepctr.feature_column import get_feature_names\n",
    "from SharedBottom import SharedBottom\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "config = tf.compat.v1.ConfigProto()\n",
    "config.gpu_options.allow_growth = True      # TensorFlow按需分配显存\n",
    "config.gpu_options.per_process_gpu_memory_fraction = 0.5  # 指定显存分配比例\n",
    "tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))\n",
    "\n",
    "# tf.config.experimental_run_functions_eagerly(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "577efc09",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_path = '../data_process'\n",
    "data = pd.read_csv(os.path.join(data_path, 'train_data.csv'), index_col=0, parse_dates=['expo_time'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "88a84ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 选择出需要用到的列\n",
    "use_cols = ['user_id', 'article_id', 'expo_time', 'net_status', 'exop_position', 'duration', 'device', 'city', 'age', 'gender', 'img_num', 'cat_1', 'click']\n",
    "data_new = data[use_cols]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ce2ea472",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 由于这个data_new的数据量还是太大， 我电脑训练不动， 所以这里再进行一波抽样\n",
    "users = set(data_new['user_id'])\n",
    "sampled_users = random.sample(users, 1000)\n",
    "data_new = data_new[data_new['user_id'].isin(sampled_users)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df0b628c",
   "metadata": {},
   "source": [
    "## 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e380f114",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 处理img_num\n",
    "def transform(x):\n",
    "    if x == '上海':\n",
    "        return 0\n",
    "    elif isinstance(x, float):\n",
    "        return float(x)\n",
    "    else:\n",
    "        return float(eval(x))\n",
    "data_new['img_num'] = data_new['img_num'].apply(lambda x: transform(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "46f77eb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "user_id_raw = data_new[['user_id']].drop_duplicates('user_id')\n",
    "doc_id_raw = data_new[['article_id']].drop_duplicates('article_id')\n",
    "\n",
    "# 简单数据预处理\n",
    "sparse_features = [\n",
    "    'user_id', 'article_id', 'net_status', 'exop_position', 'device', 'city', 'age', 'gender', 'cat_1'\n",
    "]\n",
    "dense_features = [\n",
    "    'img_num'\n",
    "]\n",
    "\n",
    "# 填充缺失值\n",
    "data_new[sparse_features] = data_new[sparse_features].fillna('-1')\n",
    "data_new[dense_features] = data_new[dense_features].fillna(0)\n",
    "\n",
    "# 归一化\n",
    "mms = MinMaxScaler(feature_range=(0, 1))\n",
    "data_new[dense_features] = mms.fit_transform(data_new[dense_features])\n",
    "\n",
    "feature_max_idx = {}\n",
    "for feat in sparse_features:\n",
    "    lbe = LabelEncoder()\n",
    "    data_new[feat] = lbe.fit_transform(data_new[feat])\n",
    "    feature_max_idx[feat] = data_new[feat].max() + 1000\n",
    "\n",
    "# 构建用户id词典和doc的id词典，方便从用户idx找到原始的id\n",
    "# user_id_enc = data[['user_id']].drop_duplicates('user_id')\n",
    "# doc_id_enc = data[['article_id']].drop_duplicates('article_id')\n",
    "# user_idx_2_rawid = dict(zip(user_id_enc['user_id'], user_id_raw['user_id']))\n",
    "# doc_idx_2_rawid = dict(zip(doc_id_enc['article_id'], doc_id_raw['article_id']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ad2691ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 划分数据集  这里按照曝光时间划分\n",
    "train_data = data_new[data_new['expo_time'] < '2021-07-06']\n",
    "test_data = data_new[data_new['expo_time'] >= '2021-07-06']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad199918",
   "metadata": {},
   "source": [
    "## 特征封装"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d0e5ea3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "sparse_feature_columns = [SparseFeat(feat, feature_max_idx[feat], embedding_dim=4) for feat in sparse_features]\n",
    "Dense_feature_columns = [DenseFeat(feat, 1) for feat in dense_features]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c9f538ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 划分dnn和linear特征\n",
    "dnn_features_columns = sparse_feature_columns + Dense_feature_columns\n",
    "lhuc_feature_columns = sparse_feature_columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f5355501",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_names = get_feature_names(dnn_features_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c6a5f286",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AttributeError: 'numpy.dtype[int64]' object has no attribute 'base_dtype' \n",
    "# Keras需要把输入声明为Keras张量，其他的比如numpy张量作为输入不好使\n",
    "train_model_input = {name: tf.keras.backend.constant(train_data[name]) for name in feature_names}\n",
    "test_model_input = {name: tf.keras.backend.constant(test_data[name]) for name in feature_names}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e4803b7",
   "metadata": {},
   "source": [
    "## 模型搭建"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9b0270fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = SharedBottom(dnn_features_columns, lhuc_feature_columns, tower_dnn_hidden_units=[], task_types=['regression', 'binary'], \n",
    "             task_names=['duration', 'click'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "df05fadd",
   "metadata": {},
   "outputs": [],
   "source": [
    "#model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb7b39f2",
   "metadata": {},
   "source": [
    "## 模型的训练和预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4f373ebd",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_duration = tf.keras.backend.constant(train_data['duration'].values)\n",
    "label_click = tf.keras.backend.constant(train_data['click'].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "01764354",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建数据管道\n",
    "train_ds = tf.data.Dataset.from_tensor_slices((train_model_input, (label_duration, label_click))).shuffle(buffer_size=100).batch(128).prefetch(tf.data.experimental.AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e9d77f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 模型训练这里，需要用到底层的训练脚本，这里不能用高层keras的API\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=0.005)\n",
    "\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_reg_loss = tf.keras.metrics.Mean(name='train_reg_loss')\n",
    "train_bin_loss = tf.keras.metrics.Mean(name='train_bin_loss')\n",
    "loss_func = {\"binary\": tf.keras.losses.binary_crossentropy, \"regression\": tf.keras.losses.mean_squared_error}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ae2f0ef9",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def train_step(features, labels, task_types, weight, l01, l02, grad_norm=False):\n",
    "    losses = []\n",
    "    gnorms = []\n",
    "    \n",
    "    # RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes\n",
    "    # 这是因为GradientTape 占用的资源默认情况下dw = t.gradient(loss, w)计算完毕就会立即释放\n",
    "    # 如果连续计算微分， 指定persistent=True\n",
    "    with tf.GradientTape(persistent=True) as tape:\n",
    "        # 遍历每个任务\n",
    "        for i, task_type in enumerate(task_types):\n",
    "            out = model(features, training=True)\n",
    "            task_loss = loss_func[task_types[i]](out[i], labels[i])\n",
    "            # print(\"task_loss\", task_loss)\n",
    "            losses.append(weight[i] * task_loss)\n",
    "                \n",
    "        # 这里更新\n",
    "        loss = tf.add_n(losses)\n",
    "        gradients = tape.gradient(loss, model.trainable_variables)\n",
    "        \n",
    "        # print('这里是loss', loss)\n",
    "        # 如果使用grad_norm\n",
    "        if grad_norm:\n",
    "            \n",
    "            # 第一步： 拿到每个任务对于最后一个共享层的梯度\n",
    "            # # 获取到loss对最后一层共享层的梯度  这里需要对最后一个共享层参数计算一遍微分\n",
    "            G1R = tape.gradient(losses[0], model.get_layer('sharedlast').trainable_variables)[0]  # 这里只用w， 不用b\n",
    "            G1 = tf.norm(G1R, ord=2)   # 求二范数\n",
    "            G2R = tape.gradient(losses[1], model.get_layer('sharedlast').trainable_variables)[0]\n",
    "            G2 = tf.norm(G2R, ord=2)  \n",
    "            \n",
    "            # 第二步： 计算平均梯度\n",
    "            G_avg = tf.math.divide(tf.add(G1, G2), 2)\n",
    "            \n",
    "            # 第三步： L_hat_i 表示当前任务训练程度\n",
    "            l_hat_1 = tf.math.divide(tf.keras.backend.mean(losses[0]), l01)\n",
    "            l_hat_2 = tf.math.divide(tf.keras.backend.mean(losses[1]), l02)\n",
    "            l_hat_avg = tf.math.divide(tf.math.add(l_hat_1, l_hat_2), 2)\n",
    "            \n",
    "            # Inverse training rates r_i(t)   tf2.x 不能tf.div， 移除了这个函数\n",
    "            inv_rate_1, inv_rate_2 = tf.math.divide(l_hat_1, l_hat_avg), tf.math.divide(l_hat_2, l_hat_avg)\n",
    "            \n",
    "            # 放大系数alpha\n",
    "            a = tf.constant(0.5)\n",
    "            C1 = tf.multiply(G_avg, tf.pow(inv_rate_1, a))\n",
    "            C2 = tf.multiply(G_avg, tf.pow(inv_rate_2, a))\n",
    "            # 看成常数， 不计算梯度\n",
    "            C1 = tf.stop_gradient(tf.identity(C1))\n",
    "            C2 = tf.stop_gradient(tf.identity(C2))\n",
    "            \n",
    "            # 第五步： 定义grad_loss\n",
    "            loss_gradnorm = tf.math.add(\n",
    "                tf.reduce_sum(tf.abs(tf.subtract(G1, C1))),\n",
    "                tf.reduce_sum(tf.abs(tf.subtract(G2, C2))))\n",
    "            \n",
    "            # 第六步： 求权重的梯度\n",
    "            weight1_grad = tape.gradient(loss_gradnorm, weight[0])\n",
    "            weight2_grad = tape.gradient(loss_gradnorm, weight[1])\n",
    "            weight_grads = [weight1_grad, weight2_grad]\n",
    "            \n",
    "    \n",
    "    # 更新所有W参数\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "    \n",
    "    \n",
    "    train_loss(loss)\n",
    "    train_reg_loss(losses[0])\n",
    "    train_bin_loss(losses[1])\n",
    "    \n",
    "    if grad_norm: \n",
    "        return loss, losses[0], losses[1], weight_grads\n",
    "    else:\n",
    "        return loss, losses[0], losses[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ca4d17f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|                                                                                           | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>]\n",
      "WARNING:tensorflow:Calling GradientTape.gradient on a persistent tape inside its context is significantly less efficient than calling it outside the context (it causes the gradient ops to be recorded on the tape, leading to increased CPU and memory usage). Only call GradientTape.gradient inside the context if you actually want to trace the gradient in order to compute higher order derivatives.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Calling GradientTape.gradient on a persistent tape inside its context is significantly less efficient than calling it outside the context (it causes the gradient ops to be recorded on the tape, leading to increased CPU and memory usage). Only call GradientTape.gradient inside the context if you actually want to trace the gradient in order to compute higher order derivatives.\n",
      " 10%|████████▎                                                                          | 1/10 [00:54<08:13, 54.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 6577.0927734375 - regression_loss: 35941.01171875 - binary_loss: 2.3232178688049316\n",
      "[<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.9875725>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0124274>]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'NoneType' object has no attribute 'shape'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_hash_fix\u001b[1;34m(self, elem)\u001b[0m\n\u001b[0;32m    108\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 109\u001b[1;33m       \u001b[0mhash\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    110\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mTypeError\u001b[0m: weak object has gone away",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32mC:\\Users\\ZHONGQ~1\\AppData\\Local\\Temp/ipykernel_63772/1809098899.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     21\u001b[0m             \u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloss_reg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloss_bin\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfeature\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtask_types\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtask_weight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0ml01\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0ml02\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_norm\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     22\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m             \u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloss_reg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloss_bin\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtask_weight_grads\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfeature\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtask_types\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtask_weight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0ml01\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0ml02\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_norm\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     24\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     25\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mgrad_norm\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    455\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    456\u001b[0m     \u001b[0mtracing_count\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_tracing_count\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 457\u001b[1;33m     \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    458\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mtracing_count\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_get_tracing_count\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    459\u001b[0m       \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call_counter\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcalled_without_tracing\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\def_function.py\u001b[0m in \u001b[0;36m_call\u001b[1;34m(self, *args, **kwds)\u001b[0m\n\u001b[0;32m    485\u001b[0m       \u001b[1;31m# In this case we have created variables on the first call, so we run the\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    486\u001b[0m       \u001b[1;31m# defunned version which is guaranteed to never create variables.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 487\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_stateless_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwds\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# pylint: disable=not-callable\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    488\u001b[0m     \u001b[1;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_stateful_fn\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    489\u001b[0m       \u001b[1;31m# Release the lock early so that multiple threads can perform the call\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1820\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m__call__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1821\u001b[0m     \u001b[1;34m\"\"\"Calls a graph function specialized to the inputs.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1822\u001b[1;33m     \u001b[0mgraph_function\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_maybe_define_function\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1823\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mgraph_function\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_filtered_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m  \u001b[1;31m# pylint: disable=protected-access\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1824\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_maybe_define_function\u001b[1;34m(self, args, kwargs)\u001b[0m\n\u001b[0;32m   2117\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2118\u001b[0m     \u001b[1;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2119\u001b[1;33m       \u001b[0mgraph_function\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_function_cache\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mprimary\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mcache_key\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2120\u001b[0m       \u001b[1;32mif\u001b[0m \u001b[0mgraph_function\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2121\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mgraph_function\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m__eq__\u001b[1;34m(self, other)\u001b[0m\n\u001b[0;32m    115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    116\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m__eq__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mother\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 117\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fields_safe\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mother\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_fields_safe\u001b[0m  \u001b[1;31m# pylint: disable=protected-access\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    118\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    119\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_fields_safe\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m     91\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m_fields_safe\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     92\u001b[0m     \u001b[1;34m\"\"\"Hash & equality-safe version of all the namedtuple fields.\"\"\"\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 93\u001b[1;33m     return (self._hash_fix(self.input_signature), self.parent_graph,\n\u001b[0m\u001b[0;32m     94\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice_functions\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcolocation_stack\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     95\u001b[0m             self.in_cross_replica_context)\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_hash_fix\u001b[1;34m(self, elem)\u001b[0m\n\u001b[0;32m     99\u001b[0m     \u001b[1;31m# Descend into tuples\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    100\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 101\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_hash_fix\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0melem\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    103\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mset\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     99\u001b[0m     \u001b[1;31m# Descend into tuples\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    100\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 101\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_hash_fix\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0melem\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    103\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mset\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_hash_fix\u001b[1;34m(self, elem)\u001b[0m\n\u001b[0;32m     99\u001b[0m     \u001b[1;31m# Descend into tuples\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    100\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 101\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_hash_fix\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0melem\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    103\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mset\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[1;34m(.0)\u001b[0m\n\u001b[0;32m     99\u001b[0m     \u001b[1;31m# Descend into tuples\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    100\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 101\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mtuple\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_hash_fix\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0melem\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    103\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0melem\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mset\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\anaconda3\\envs\\tfenv\\lib\\site-packages\\tensorflow_core\\python\\eager\\function.py\u001b[0m in \u001b[0;36m_hash_fix\u001b[1;34m(self, elem)\u001b[0m\n\u001b[0;32m    110\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    111\u001b[0m       \u001b[0mv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0melem\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 112\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtensor_spec\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensorSpec\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    113\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    114\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0melem\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'shape'"
     ]
    }
   ],
   "source": [
    "epochs = 10\n",
    "best_test_loss = float('inf')\n",
    "task_types = [\"regression\", \"binary\"]\n",
    "\n",
    "task_weight = [tf.Variable(1.0, trainable=True), tf.Variable(1.0, trainable=True)]\n",
    "\n",
    "l01 = tf.Variable(tf.math.log(2.), trainable=False)\n",
    "l02 = tf.Variable(tf.math.log(2.), trainable=False)\n",
    "\n",
    "grad_norm = True\n",
    "\n",
    "for epoch in tqdm(range(1, epochs+1)):\n",
    "    \n",
    "    print(task_weight)\n",
    "    train_loss.reset_states()\n",
    "    train_reg_loss.reset_states()\n",
    "    train_bin_loss.reset_states()\n",
    "\n",
    "    for feature, labels in train_ds:\n",
    "        if not grad_norm:\n",
    "            loss, loss_reg, loss_bin = train_step(feature, labels, task_types, task_weight, l01, l02, grad_norm)\n",
    "        else:\n",
    "            loss, loss_reg, loss_bin, task_weight_grads = train_step(feature, labels, task_types, task_weight, l01, l02, grad_norm)\n",
    "\n",
    "    if grad_norm: \n",
    "        # 更新权重参数  \n",
    "        # 这里的一个坑： 这个一定要放到epoch下更新w，不能放到train_step里面，放到里面，相当于每个batch级更新\n",
    "        # 而每个batch差别很大，经过几个batch级别的迭代，这里的loss就会变成nan， 一定要放到外面\n",
    "        optimizer.apply_gradients(zip(task_weight_grads, task_weight))\n",
    "        \n",
    "        # 如果两者某一个出现了nan\n",
    "        if tf.compat.v1.is_nan(task_weight[0]) or tf.compat.v1.is_nan(task_weight[1]):\n",
    "            task_weight = [tf.Variable(1.0, trainable=True), tf.Variable(1.0, trainable=True)]\n",
    "        \n",
    "        else:\n",
    "            #weight参数需要renormalize下   这里如果不renormalize， 更新完的梯度会有nan值，此时会造成loss直接变成nan\n",
    "            coef = tf.math.divide(2.0, tf.add(task_weight[0], task_weight[1]))\n",
    "\n",
    "            task_weight[0] = tf.Variable(tf.multiply(task_weight[0], coef))\n",
    "            task_weight[1] = tf.Variable(tf.multiply(task_weight[1], coef))\n",
    "        \n",
    "        \n",
    "    template = 'Epoch {}, Loss: {} - regression_loss: {} - binary_loss: {}'\n",
    "    print(template.format(epoch, train_loss.result(), \n",
    "                          np.mean(loss_reg), \n",
    "                         np.mean(loss_bin)))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d671e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict on test_model_input in batches of 256; later cells index the result per task.\n",
    "pred_ans = model.predict(\n",
    "    test_model_input,\n",
    "    batch_size=256,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32dc4aa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ROC AUC of the click labels against pred_ans[1], rounded to 4 decimals.\n",
    "print(\n",
    "    \"test click AUC\",\n",
    "    round(roc_auc_score(y_true=test_data['click'], y_score=pred_ans[1]), 4),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c377e25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean absolute error of the duration labels against pred_ans[0], rounded to 4 decimals.\n",
    "print(\n",
    "    \"test duration\",\n",
    "    round(mean_absolute_error(y_true=test_data['duration'], y_pred=pred_ans[0]), 4),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
